Add rampup batch size support in MaxText #2535
Conversation
Force-pushed 57ff3e8 → 842193d
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
Force-pushed 842193d → 82f5edf
It seems we're out of quota for the free tier. We're going to upgrade to Tier 1; it should be better soon.

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
Force-pushed 4114fa7 → b911ca8
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a batch size ramp-up feature to improve training stability, which is a valuable addition. The implementation is generally clean and follows existing patterns in the codebase. The configuration, data loading, and metric logging components are well-integrated.
🔍 General Feedback
- The refactoring of the sharding logic out of the `DataLoader` and into the training loop is a good improvement for separation of concerns.
- The use of a factory function `create_dataloader` is a clean way to handle the conditional creation of the `RampUpDataLoader`.
- I've left a couple of minor suggestions to improve an assertion message and to correct the logic in a unit test.
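The factory pattern the review praises could look roughly like the sketch below. The `Config` dataclass and the constructor signatures are illustrative stand-ins, not MaxText's actual API; only the names `create_dataloader`, `DataLoader`, and `RampUpDataLoader` come from the review itself.

```python
from dataclasses import dataclass


# Illustrative stand-in for MaxText's config object.
@dataclass
class Config:
    enable_rampup_batch_size: bool = False


class DataLoader:
    """Base loader: yields full-size batches."""

    def __init__(self, config):
        self.config = config


class RampUpDataLoader(DataLoader):
    """Variant that truncates batches during the ramp-up phase."""


def create_dataloader(config):
    # Factory: only instantiate the ramp-up variant when the feature is enabled,
    # so the common path is untouched by the new logic.
    if config.enable_rampup_batch_size:
        return RampUpDataLoader(config)
    return DataLoader(config)
```

Keeping the conditional inside one factory means call sites never branch on `enable_rampup_batch_size` themselves.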
RissyRan left a comment:
Thanks Nuojing! LGTM! If you could respond to Gemini's comments, that would be great! We can chat more about future deliveries offline, but that's not a blocker for this PR.
Force-pushed 24e0d29 → 5e7f164
RissyRan left a comment:
LGTM for this PR! Let's chat more on the remaining tasks in the meeting today. Thank you!
Force-pushed 47a1028 → d3fed21
Force-pushed d3fed21 → 40e056d
We discussed offline and agreed to merge this feature first, leaving two action items for follow-up: 1) learning-rate adjustment, if needed, due to this feature; 2) metrics reporting, such as throughput, during the batch ramp-up.
Force-pushed 3375760 → 9143004
-- 40e056d by NuojCheng <[email protected]>: add rampup batch size COPYBARA_INTEGRATE_REVIEW=#2535 from AI-Hypercomputer:chengnuojin-rampup-batch 40e056d PiperOrigin-RevId: 827037473
Force-pushed 9143004 → 3da3587
Description
This PR adds support for ramp-up batch size, a feature originally proposed in the GPT-3 paper and implemented in Megatron.
When enabled, the per-device batch size starts at a smaller value (`per_device_batch_size_start`) and gradually increases (by `per_device_batch_size_increment`) until it reaches the target `per_device_batch_size` over a specified number of `rampup_samples`. This can help improve training stability, especially during the initial training phases.

This feature introduces four new configuration parameters, which align with the Megatron implementation:
- `enable_rampup_batch_size`: (default: `False`) Set to `True` to enable the ramp-up feature.
- `per_device_batch_size_start`: The per-device batch size to use at the beginning of training.
- `per_device_batch_size_increment`: The amount by which to increase the per-device batch size at each ramp-up step.
- `global_rampup_samples`: The total number of samples to process before reaching the full target batch size.

The PR includes the following changes:
- `RampupDataLoader`: Adds a new `RampupDataLoader` class that inherits from the base `DataLoader`. Its primary responsibility is to truncate the input data to match the correct ramp-up shape for the current training step.
- Updates `pyconfig.py` to register and validate the new ramp-up configuration parameters.
- Adds tests in `data_loader_tests.py` to verify the `RampupDataLoader`'s slicing and increment logic.

FIXES: b/452468482
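A Megatron-style ramp-up schedule of this kind can be sketched as below. The function name and the exact rounding behavior are assumptions for illustration, not the PR's actual implementation; it only shows how the four parameters interact.

```python
def rampup_batch_size(consumed_samples, start, increment, target, rampup_samples):
    """Per-device batch size after `consumed_samples` samples have been seen.

    The batch size climbs from `start` to `target` in steps of `increment`,
    with the increments spread evenly across `rampup_samples` samples
    (hypothetical sketch mirroring Megatron's ramp-up behavior).
    """
    if consumed_samples >= rampup_samples:
        return target
    num_increments = (target - start) // increment       # e.g. (16 - 4) // 4 = 3 steps
    samples_per_increment = rampup_samples / num_increments
    steps_done = int(consumed_samples / samples_per_increment)
    return min(start + steps_done * increment, target)


# Ramp from 4 to 16 in steps of 4 over 300 samples:
# samples 0-99 -> 4, 100-199 -> 8, 200-299 -> 12, >= 300 -> 16.
```

The loader would then slice each incoming batch down to the value this schedule returns for the current step, which matches the truncation responsibility described for `RampupDataLoader` above.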
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.